import json
import os
from itertools import cycle
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import networkx as nx

def load_knowledge_graph(json_file: str) -> Dict:
    """Read ``json_file`` as UTF-8 encoded JSON and return the decoded object.

    The caller is expected to pass a knowledge-graph file whose top level is a
    mapping (with ``'nodes'`` and ``'edges'`` keys as used elsewhere here).
    """
    with open(json_file, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def get_category_colors() -> List[str]:
    """Return the palette of hex colours used for node categories.

    A new list is returned on every call, so callers may mutate it freely.
    Colours are consumed positionally (one per category), not by name.
    """
    colors = ['#edd2aa', '#ed726e', '#68e9e0', '#b3be66', '#7daa53', '#c75c58', '#717cf7', '#f3b26f', '#9b7fb3', '#65c3a1']
    return colors

def get_relation_colors() -> List[str]:
    """Return the palette of hex colours used for edges.

    The list is positionally aligned with ``get_base_relations()``: entry ``i``
    colours base relation ``i``. A new list is returned on every call.
    """
    colors = ['#e0ebf7', '#e6f1db', '#e5e5f9', '#f2e6f9', '#fdf3d0', '#f6d7d8', '#f0f4cc', '#b4b4e9']
    return colors

def get_base_relations():
    """Return the canonical (base) relation names as a fresh list.

    The order matters: ``visualize_knowledge_graph`` pairs this list
    positionally with ``get_relation_colors()`` and uses it for legend order.
    """
    canonical = [
        'is_a',
        'part_of',
        'leads_to',
        'depends_on',
        'similar_to',
        'opposite_of',
        'applies_to',
        'composed_of',
    ]
    return canonical

# NOTE: relations should be normalized to the fixed base set when the graph is
# updated, rather than being mapped after the fact here.
# TODO: have graph extraction emit only the fixed base relation words.
# Raw relation words grouped by the base relation they collapse to. Each alias
# must appear in exactly one group. Base relations not listed as an alias of
# themselves (e.g. 'depends_on') pass through unchanged via the default below.
_RELATION_ALIASES: Dict[str, Tuple[str, ...]] = {
    'applies_to': (
        'action_on', 'acts_on', 'aims_to', 'analyzes', 'applies_to',
        'directed_at', 'target_audience', 'targets',
    ),
    'leads_to': (
        'affects', 'binds', 'causes', 'enables', 'enhances', 'facilitates',
        'implies', 'influences', 'leads_to', 'manipulates', 'promotes',
        'promotes_action', 'shapes', 'transforms', 'triggers',
    ),
    'part_of': (
        'contains', 'includes', 'includes_process', 'part_of',
    ),
    'depends_on': (
        'employs', 'used_by', 'used_for', 'used_in', 'used_with', 'uses',
    ),
    'composed_of': (
        'authorizes', 'avoids', 'certifies', 'creates', 'defends', 'discards',
        'disguises', 'educates', 'enacts', 'engages_in', 'establishes',
        'evades', 'exhibits', 'experiences', 'generates', 'guides', 'hides',
        'hosts', 'ignores', 'implements', 'inspects', 'integrates', 'intends',
        'obtains', 'operates', 'performed', 'performs', 'pretends_to_be',
        'produces', 'protects', 'requests', 'reveals', 'revives_with', 'seeks',
        'seeks_to_establish', 'stitches', 'submits_to', 'suggests', 'tests',
        'traces', 'unites',
    ),
    'opposite_of': (
        'demeans', 'diverts_from', 'exploits', 'misleads_with', 'negates',
        'opposes', 'prevents',
    ),
    'similar_to': (
        'depicts', 'describes', 'domain_specific', 'has_property',
        'interacts_with', 'involves', 'misinformation_vector', 'occurs_during',
        'occurs_on', 'parallel_execution', 'property_of', 'related_to',
        'represents', 'synonym', 'synonym_of', 'weakness_of',
    ),
    'is_a': (
        'controls', 'is_a',
    ),
}

# Flat alias -> base lookup, built once at import time instead of on every call.
_BASE_RELATION_MAP: Dict[str, str] = {
    alias: base
    for base, aliases in _RELATION_ALIASES.items()
    for alias in aliases
}

def get_base_relation(relation: str) -> str:
    """Map a raw relation word to one of the base relations.

    Args:
        relation: Relation label as stored on a knowledge-graph edge.

    Returns:
        The base relation ``relation`` is an alias of, or ``relation`` itself
        when it has no mapping (so unknown words pass through unchanged).
    """
    return _BASE_RELATION_MAP.get(relation, relation)

def get_relation_display_name(relation: str) -> Optional[str]:
    """Return the human-readable legend label for a base relation.

    Args:
        relation: A base relation name such as ``'part_of'``.

    Returns:
        The title-cased label (e.g. ``'Part Of'``), or ``None`` when
        ``relation`` is not one of the base relations.
    """
    display_names = {
        'is_a': 'Is A',
        'part_of': 'Part Of',
        'leads_to': 'Leads To',
        'depends_on': 'Depends On',
        'similar_to': 'Similar To',
        'opposite_of': 'Opposite Of',
        'applies_to': 'Applies To',
        'composed_of': 'Composed Of'
    }
    return display_names.get(relation)

def visualize_knowledge_graph(
    knowledge_graph: Dict,
    output_path: Optional[str] = None,
    title: str = "Knowledge Graph Visualization",
    figsize: Tuple[int, int] = (16, 12),
    node_size: int = 400,
    edge_width: float = 1.5,
    font_size: int = 8,
    show_labels: bool = True,
    seed: Optional[int] = None,
) -> plt.Figure:
    """Draw a knowledge graph, colouring nodes by category and edges by base relation.

    Args:
        knowledge_graph: Mapping with a ``'nodes'`` dict (node id -> attribute
            dict; ``'category'`` and ``'term'`` are read) and an ``'edges'``
            dict (source id -> iterable of ``(target_id, relation_type)`` pairs).
        output_path: If truthy, the figure is also saved there at 300 dpi.
        title: Figure title.
        figsize: Figure size passed to ``plt.subplots``.
        node_size: Node marker size; also used so arrowheads stop at node borders.
        edge_width: Edge line width.
        font_size: Font size of node labels.
        show_labels: Whether to draw each node's ``'term'`` as its label.
        seed: Optional seed for ``nx.spring_layout`` for a reproducible layout.
            ``None`` (the default) keeps the unseeded random layout.

    Returns:
        The matplotlib figure containing the drawing.
    """
    nodes = knowledge_graph['nodes']

    # Sort (key=str tolerates non-string values) so the category -> colour
    # assignment is stable across runs; set order of str depends on hash seeding.
    categories = sorted({data.get('category', 'unknown') for data in nodes.values()}, key=str)
    # Cycle the palette so categories beyond its length still get a colour and
    # a legend entry instead of silently falling back to grey.
    color_map = dict(zip(categories, cycle(get_category_colors())))

    base_relations = get_base_relations()
    relation_color_map = dict(zip(base_relations, get_relation_colors()))

    G = nx.DiGraph()
    for node_id, node_data in nodes.items():
        G.add_node(node_id, label=node_data.get('term', ''), category=node_data.get('category', 'unknown'))
    for source_id, edges in knowledge_graph['edges'].items():
        for target_id, relation_type in edges:
            # DiGraph keeps one edge per (source, target): a repeated pair
            # keeps the relation seen last.
            G.add_edge(source_id, target_id, relation=relation_type)

    # Derive colours from the finished graph and pass the matching nodelist /
    # edgelist explicitly, so colours line up with what networkx draws even
    # when edge endpoints are missing from 'nodes' or pairs are repeated.
    nodelist = list(G.nodes)
    node_colors = [color_map.get(G.nodes[n].get('category', 'unknown'), '#808080') for n in nodelist]
    edge_records = list(G.edges(data='relation'))
    edgelist = [(u, v) for u, v, _ in edge_records]
    edge_colors = [relation_color_map.get(get_base_relation(rel), '#808080') for _, _, rel in edge_records]

    fig, ax = plt.subplots(figsize=figsize)
    pos = nx.spring_layout(G, k=0.5, iterations=50, seed=seed)

    nx.draw_networkx_nodes(G, pos, nodelist=nodelist, node_color=node_colors, node_size=node_size, alpha=0.8, ax=ax)
    nx.draw_networkx_edges(
        G, pos, edgelist=edgelist, edge_color=edge_colors, width=edge_width,
        node_size=node_size, arrowsize=15, arrowstyle='->', alpha=0.7, ax=ax,
    )

    if show_labels:
        # Nodes created implicitly by an edge carry no 'label'; show their id.
        labels = {n: G.nodes[n].get('label', str(n)) for n in nodelist}
        nx.draw_networkx_labels(G, pos, labels=labels, font_size=font_size, font_weight='bold', ax=ax)

    category_legend = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=str(cat))
        for cat, color in color_map.items()
    ]
    relation_legend = [
        plt.Line2D([0], [0], color=relation_color_map.get(rel, '#808080'), label=get_relation_display_name(rel), linewidth=2)
        for rel in base_relations
    ]

    # A second ax.legend() call replaces the first, so re-attach it as an artist.
    category_box = ax.legend(handles=category_legend, loc='upper left', title="Categories")
    ax.add_artist(category_box)
    ax.legend(handles=relation_legend, loc='upper right', title="Relations")

    ax.set_title(title, fontsize=16)
    ax.axis('off')

    if output_path:
        fig.savefig(output_path, bbox_inches='tight', dpi=300)

    return fig

def get_all_relations(knowledge_graph):
    """Return every distinct relation label used on any edge, sorted ascending.

    ``knowledge_graph['edges']`` maps each source id to an iterable of
    ``(target_id, relation_type)`` pairs; only the relation part is collected.
    """
    unique_relations = {
        relation
        for edge_list in knowledge_graph['edges'].values()
        for _target, relation in edge_list
    }
    return sorted(unique_relations)

def main():
    """Load the formalization knowledge graph, print its relations, and plot it.

    Paths are resolved relative to the directory containing this file. The
    input is read from ``output/ds/jbb/knowledge_graph/formalization_kg.json``.
    The plot is saved to ``output/ds/knowledge_graph.png`` and then shown.
    """
    # abspath so the paths do not depend on how the script was launched
    # (``__file__`` can be relative on older Pythons, making dirname return '').
    project_root = os.path.dirname(os.path.abspath(__file__))
    json_file = os.path.join(project_root, "output", "ds", "jbb", "knowledge_graph", "formalization_kg.json")
    knowledge_graph = load_knowledge_graph(json_file)
    print(get_all_relations(knowledge_graph))

    output_path = os.path.join(project_root, "output", "ds", "knowledge_graph.png")

    visualize_knowledge_graph(
        knowledge_graph,
        output_path=output_path,
        title="PASS Knowledge Graph",
        figsize=(20, 10),
        node_size=100,
        font_size=2,
    )

    plt.show()

# Run the load/plot pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
